git_adr/ai/
provider.rs

//! AI provider abstraction: the supported providers and their per-provider configuration.
2
3use crate::Error;
4
/// Supported AI providers.
///
/// The enum is `Copy`, and it also derives `Eq` and `Hash`, so it can be
/// passed by value and used as a key in hash maps/sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AiProvider {
    /// Anthropic Claude.
    Anthropic,
    /// OpenAI GPT.
    OpenAi,
    /// Google Gemini.
    Google,
    /// Local Ollama.
    Ollama,
}
17
18impl std::fmt::Display for AiProvider {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        match self {
21            Self::Anthropic => write!(f, "anthropic"),
22            Self::OpenAi => write!(f, "openai"),
23            Self::Google => write!(f, "google"),
24            Self::Ollama => write!(f, "ollama"),
25        }
26    }
27}
28
29impl std::str::FromStr for AiProvider {
30    type Err = Error;
31
32    fn from_str(s: &str) -> Result<Self, Self::Err> {
33        match s.to_lowercase().as_str() {
34            "anthropic" | "claude" => Ok(Self::Anthropic),
35            "openai" | "gpt" => Ok(Self::OpenAi),
36            "google" | "gemini" => Ok(Self::Google),
37            "ollama" | "local" => Ok(Self::Ollama),
38            _ => Err(Error::InvalidProvider {
39                provider: s.to_string(),
40            }),
41        }
42    }
43}
44
45/// Configuration for an AI provider.
46#[derive(Debug, Clone)]
47pub struct ProviderConfig {
48    /// The provider to use.
49    pub provider: AiProvider,
50    /// Model name (provider-specific).
51    pub model: String,
52    /// API key or endpoint.
53    pub api_key: Option<String>,
54    /// Base URL for the API.
55    pub base_url: Option<String>,
56    /// Temperature for generation.
57    pub temperature: f32,
58    /// Maximum tokens to generate.
59    pub max_tokens: u32,
60}
61
62impl Default for ProviderConfig {
63    fn default() -> Self {
64        Self {
65            provider: AiProvider::Anthropic,
66            model: "claude-3-haiku-20240307".to_string(),
67            api_key: None,
68            base_url: None,
69            temperature: 0.7,
70            max_tokens: 2048,
71        }
72    }
73}
74
75impl ProviderConfig {
76    /// Create a new config for the given provider.
77    #[must_use]
78    pub fn new(provider: AiProvider) -> Self {
79        let model = match provider {
80            AiProvider::Anthropic => "claude-3-haiku-20240307".to_string(),
81            AiProvider::OpenAi => "gpt-4o-mini".to_string(),
82            AiProvider::Google => "gemini-1.5-flash".to_string(),
83            AiProvider::Ollama => "llama3.2".to_string(),
84        };
85
86        Self {
87            provider,
88            model,
89            ..Default::default()
90        }
91    }
92
93    /// Set the model.
94    #[must_use]
95    pub fn with_model(mut self, model: impl Into<String>) -> Self {
96        self.model = model.into();
97        self
98    }
99
100    /// Set the API key.
101    #[must_use]
102    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
103        self.api_key = Some(key.into());
104        self
105    }
106
107    /// Set the base URL.
108    #[must_use]
109    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
110        self.base_url = Some(url.into());
111        self
112    }
113
114    /// Set the temperature.
115    #[must_use]
116    pub fn with_temperature(mut self, temp: f32) -> Self {
117        self.temperature = temp;
118        self
119    }
120
121    /// Get the API key from config or environment.
122    ///
123    /// # Errors
124    ///
125    /// Returns an error if no API key is available.
126    pub fn get_api_key(&self) -> Result<String, Error> {
127        if let Some(key) = &self.api_key {
128            return Ok(key.clone());
129        }
130
131        let env_var = match self.provider {
132            AiProvider::Anthropic => "ANTHROPIC_API_KEY",
133            AiProvider::OpenAi => "OPENAI_API_KEY",
134            AiProvider::Google => "GOOGLE_API_KEY",
135            AiProvider::Ollama => return Ok(String::new()), // Ollama doesn't need a key
136        };
137
138        std::env::var(env_var).map_err(|_| Error::AiNotConfigured {
139            message: format!("{env_var} not set"),
140        })
141    }
142
143    /// Get the base URL from config or environment.
144    #[must_use]
145    pub fn get_base_url(&self) -> Option<String> {
146        if let Some(url) = &self.base_url {
147            return Some(url.clone());
148        }
149
150        match self.provider {
151            AiProvider::Ollama => Some(
152                std::env::var("OLLAMA_HOST")
153                    .unwrap_or_else(|_| "http://localhost:11434".to_string()),
154            ),
155            _ => None,
156        }
157    }
158}