
Commit 6a2bc47

Author: Joevenner
feat: Add multi-provider LLM support via LiteLLM integration
Replace the OpenAI-only implementation with LiteLLM to support 100+ LLM providers, including Anthropic Claude, Google Gemini, Azure OpenAI, AWS Bedrock, Groq, and local Ollama models.

Changes:
- Add `litellm>=1.0.0` dependency
- Refactor the `ChatGPT_API` functions to use `litellm.completion()`
- Enhance `count_tokens()` for multi-provider token counting
- Update `config.yaml` with provider-specific model examples
- Update `README.md` with multi-provider setup instructions

Backward compatible: existing `OPENAI_API_KEY` and `CHATGPT_API_KEY` still work. The default model remains `gpt-4o-2024-11-20`.
Parent: a061d53

4 files changed: 197 additions, 39 deletions


README.md

Lines changed: 46 additions & 3 deletions
````diff
@@ -147,14 +147,49 @@ You can follow these steps to generate a PageIndex tree from a PDF document.
 pip3 install --upgrade -r requirements.txt
 ```
 
-### 2. Set your OpenAI API key
+### 2. Set your API key
 
-Create a `.env` file in the root directory and add your API key:
+PageIndex now supports multiple LLM providers via [LiteLLM](https://docs.litellm.ai/). Create a `.env` file in the root directory and add your API key:
 
+**OpenAI (default):**
 ```bash
+OPENAI_API_KEY=your_openai_key_here
+# or
 CHATGPT_API_KEY=your_openai_key_here
 ```
 
+**Anthropic Claude:**
+```bash
+ANTHROPIC_API_KEY=your_anthropic_key_here
+```
+
+**Google Gemini:**
+```bash
+GEMINI_API_KEY=your_google_key_here
+```
+
+**Azure OpenAI:**
+```bash
+AZURE_API_KEY=your_azure_key_here
+AZURE_API_BASE=your_azure_endpoint
+AZURE_API_VERSION=2024-02-01
+```
+
+**AWS Bedrock:**
+```bash
+AWS_ACCESS_KEY_ID=your_access_key
+AWS_SECRET_ACCESS_KEY=your_secret_key
+AWS_REGION_NAME=us-east-1
+```
+
+**Groq:**
+```bash
+GROQ_API_KEY=your_groq_key_here
+```
+
+**Ollama (local):**
+No API key needed. Just ensure Ollama is running locally.
+
 ### 3. Run PageIndex on your PDF
 
 ```bash
@@ -167,7 +202,15 @@ python3 run_pageindex.py --pdf_path /path/to/your/document.pdf
 You can customize the processing with additional optional arguments:
 
 ```
---model                 OpenAI model to use (default: gpt-4o-2024-11-20)
+--model                 LLM model to use (default: gpt-4o-2024-11-20)
+                        Examples:
+                        - OpenAI: gpt-4o, gpt-4-turbo
+                        - Claude: claude-3-opus-20240229, claude-3-sonnet-20240229
+                        - Gemini: gemini/gemini-pro, gemini/gemini-1.5-pro
+                        - Azure: azure/your-deployment-name
+                        - Bedrock: bedrock/anthropic.claude-3-opus-20240229-v1:0
+                        - Groq: groq/llama-3.1-70b-versatile
+                        - Ollama: ollama/llama3
 --toc-check-pages       Pages to check for table of contents (default: 20)
 --max-pages-per-node    Max pages per node (default: 10)
 --max-tokens-per-node   Max tokens per node (default: 20000)
````
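The `--model` examples above follow LiteLLM's convention: an optional `provider/` prefix selects the backend, and bare model names default to OpenAI. A minimal sketch of that routing rule (`provider_for` is an illustrative helper, not part of this repo, and LiteLLM's real routing table recognizes many more bare-name prefixes):

```python
def provider_for(model: str) -> str:
    """Return the provider implied by a LiteLLM-style model string.

    Simplified rule: an explicit "provider/" prefix wins; otherwise a
    bare "claude-*" name goes to Anthropic and everything else is
    treated as OpenAI. (LiteLLM's real routing covers far more cases.)
    """
    if "/" in model:
        return model.split("/", 1)[0]
    if model.startswith("claude"):
        return "anthropic"
    return "openai"

# The README's --model examples route like this:
for m in ["gpt-4o", "claude-3-opus-20240229", "gemini/gemini-1.5-pro",
          "azure/your-deployment-name", "groq/llama-3.1-70b-versatile",
          "ollama/llama3"]:
    print(f"{m} -> {provider_for(m)}")
```

This is why switching providers only requires changing the `--model` string plus setting the matching environment variable from the list above.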

pageindex/config.yaml

Lines changed: 35 additions & 0 deletions
```diff
@@ -1,3 +1,38 @@
+# PageIndex Configuration
+#
+# Model Configuration:
+# PageIndex now supports multiple LLM providers via LiteLLM.
+# Set the model string according to your preferred provider:
+#
+# OpenAI (default):
+#   model: "gpt-4o-2024-11-20" or "gpt-4o" or "gpt-4-turbo"
+#   Env var: OPENAI_API_KEY or CHATGPT_API_KEY
+#
+# Anthropic Claude:
+#   model: "claude-3-opus-20240229" or "claude-3-sonnet-20240229"
+#   Env var: ANTHROPIC_API_KEY
+#
+# Google Gemini:
+#   model: "gemini/gemini-pro" or "gemini/gemini-1.5-pro"
+#   Env var: GEMINI_API_KEY
+#
+# Azure OpenAI:
+#   model: "azure/your-deployment-name"
+#   Env vars: AZURE_API_KEY, AZURE_API_BASE, AZURE_API_VERSION
+#
+# AWS Bedrock:
+#   model: "bedrock/anthropic.claude-3-opus-20240229-v1:0"
+#   Env vars: AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION_NAME
+#
+# Groq:
+#   model: "groq/llama-3.1-70b-versatile"
+#   Env var: GROQ_API_KEY
+#
+# Ollama (local):
+#   model: "ollama/llama3"
+#
+# For more providers, see: https://docs.litellm.ai/docs/providers
+
 model: "gpt-4o-2024-11-20"
 toc_check_page_num: 20
 max_page_num_each_node: 10
```
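Only the three active keys at the bottom of the file are consumed by the code; `utils.py` aliases `SimpleNamespace` as `config`, so loaded settings behave like attributes. A hard-coded sketch (values copied from the YAML above; the repo presumably loads them with a YAML parser rather than literals):

```python
from types import SimpleNamespace

# Values mirroring config.yaml; hard-coded here only for illustration.
settings = SimpleNamespace(
    model="gpt-4o-2024-11-20",
    toc_check_page_num=20,
    max_page_num_each_node=10,
)

print(settings.model)  # attribute access, as SimpleNamespace-style config allows
```

Switching providers is then a one-string edit, e.g. `settings.model = "claude-3-opus-20240229"`, with the matching environment variable set.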

pageindex/utils.py

Lines changed: 114 additions & 35 deletions
```diff
@@ -1,5 +1,5 @@
 import tiktoken
-import openai
+import litellm
 import logging
 import os
 from datetime import datetime
@@ -17,30 +17,82 @@
 from pathlib import Path
 from types import SimpleNamespace as config
 
-CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
+# Support multiple API key environment variables for different providers
+CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY") or os.getenv("OPENAI_API_KEY")
+
+# Configure LiteLLM to use environment variables for different providers
+# Users can set: OPENAI_API_KEY, ANTHROPIC_API_KEY, GEMINI_API_KEY, etc.
+# See: https://docs.litellm.ai/docs/providers
 
 def count_tokens(text, model=None):
+    """
+    Count tokens in text using the appropriate tokenizer for the model.
+    Uses tiktoken for OpenAI models and LiteLLM's token counter for other providers.
+    """
     if not text:
         return 0
-    enc = tiktoken.encoding_for_model(model)
-    tokens = enc.encode(text)
-    return len(tokens)
+
+    # Check if it's an OpenAI model (no prefix or openai/ prefix)
+    model_lower = model.lower() if model else ""
+    is_openai_model = (
+        "/" not in model_lower or
+        model_lower.startswith("openai/") or
+        model_lower.startswith("gpt-") or
+        model_lower.startswith("o1-") or
+        model_lower.startswith("o3-")
+    )
+
+    if is_openai_model:
+        # Use tiktoken for OpenAI models
+        try:
+            # Strip the openai/ prefix if present
+            clean_model = model.replace("openai/", "") if model else "gpt-4o"
+            enc = tiktoken.encoding_for_model(clean_model)
+            tokens = enc.encode(text)
+            return len(tokens)
+        except KeyError:
+            # Fall back to cl100k_base encoding for unknown OpenAI models
+            enc = tiktoken.get_encoding("cl100k_base")
+            tokens = enc.encode(text)
+            return len(tokens)
+    else:
+        # Use LiteLLM's token counter for other providers
+        try:
+            return litellm.token_counter(model=model, text=text)
+        except Exception:
+            # Fall back to approximate counting (~4 characters per token)
+            return len(text) // 4
 
-def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API_with_finish_reason(model, prompt, api_key=None, chat_history=None):
+    """
+    Synchronous chat completion API with finish reason tracking.
+    Uses LiteLLM to support multiple LLM providers.
+
+    Args:
+        model: Model string (e.g., "gpt-4o", "claude-3-opus-20240229", "gemini/gemini-pro")
+        prompt: User prompt
+        api_key: API key (optional, uses environment variables if not provided)
+        chat_history: Previous conversation history
+
+    Returns:
+        Tuple of (response_content, finish_reason)
+    """
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+
+    # Build messages list
+    if chat_history:
+        messages = chat_history.copy()
+        messages.append({"role": "user", "content": prompt})
+    else:
+        messages = [{"role": "user", "content": prompt}]
+
     for i in range(max_retries):
         try:
-            if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
-            else:
-                messages = [{"role": "user", "content": prompt}]
-
-            response = client.chat.completions.create(
+            response = litellm.completion(
                 model=model,
                 messages=messages,
                 temperature=0,
+                api_key=api_key,
             )
             if response.choices[0].finish_reason == "length":
                 return response.choices[0].message.content, "max_output_reached"
@@ -51,53 +103,80 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                time.sleep(1)  # Wait for 1秒 before retrying
+                time.sleep(1)  # Wait for 1s before retrying
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
-                return "Error"
+                return "Error", "error"
 
 
-def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API(model, prompt, api_key=None, chat_history=None):
+    """
+    Synchronous chat completion API.
+    Uses LiteLLM to support multiple LLM providers.
+
+    Args:
+        model: Model string (e.g., "gpt-4o", "claude-3-opus-20240229", "gemini/gemini-pro")
+        prompt: User prompt
+        api_key: API key (optional, uses environment variables if not provided)
+        chat_history: Previous conversation history
+
+    Returns:
+        Response content string
+    """
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+
+    # Build messages list
+    if chat_history:
+        messages = chat_history.copy()
+        messages.append({"role": "user", "content": prompt})
+    else:
+        messages = [{"role": "user", "content": prompt}]
+
     for i in range(max_retries):
         try:
-            if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
-            else:
-                messages = [{"role": "user", "content": prompt}]
-
-            response = client.chat.completions.create(
+            response = litellm.completion(
                 model=model,
                 messages=messages,
                 temperature=0,
+                api_key=api_key,
             )
-
             return response.choices[0].message.content
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
             if i < max_retries - 1:
-                time.sleep(1)  # Wait for 1秒 before retrying
+                time.sleep(1)  # Wait for 1s before retrying
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
                 return "Error"
 
 
-async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
+async def ChatGPT_API_async(model, prompt, api_key=None):
+    """
+    Asynchronous chat completion API.
+    Uses LiteLLM to support multiple LLM providers.
+
+    Args:
+        model: Model string (e.g., "gpt-4o", "claude-3-opus-20240229", "gemini/gemini-pro")
+        prompt: User prompt
+        api_key: API key (optional, uses environment variables if not provided)
+
+    Returns:
+        Response content string
+    """
     max_retries = 10
     messages = [{"role": "user", "content": prompt}]
+
    for i in range(max_retries):
        try:
-            async with openai.AsyncOpenAI(api_key=api_key) as client:
-                response = await client.chat.completions.create(
-                    model=model,
-                    messages=messages,
-                    temperature=0,
-                )
-                return response.choices[0].message.content
+            response = await litellm.acompletion(
+                model=model,
+                messages=messages,
+                temperature=0,
+                api_key=api_key,
+            )
+            return response.choices[0].message.content
        except Exception as e:
            print('************* Retrying *************')
            logging.error(f"Error: {e}")
```
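All three API wrappers share the same retry shape: attempt the completion, log and sleep on failure, and return a sentinel once retries are exhausted. A generic, self-contained sketch of that loop (`with_retries` and `flaky` are illustrative names, not repo code; the real functions sleep 1 s between attempts and log via `logging`):

```python
import time

def with_retries(call, max_retries=10, delay=0.0):
    """Retry `call` up to max_retries times: on failure, wait and try
    again; once retries are exhausted, return the sentinel "Error".
    (Sketch only; the repo's wrappers use a 1-second delay.)"""
    for i in range(max_retries):
        try:
            return call()
        except Exception:
            if i < max_retries - 1:
                time.sleep(delay)
            else:
                return "Error"

# A fake "completion" that fails twice before succeeding:
attempts = {"n": 0}
def flaky():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

print(with_retries(flaky))  # succeeds on the third attempt
```

Returning a sentinel string rather than raising is a deliberate trade-off in this codebase; note that `ChatGPT_API_with_finish_reason` now returns a `("Error", "error")` tuple on exhaustion so callers can unpack it consistently.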

requirements.txt

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,4 +1,5 @@
-openai==1.101.0
+litellm>=1.0.0
+openai>=1.0.0
 pymupdf==1.26.4
 PyPDF2==3.0.1
 python-dotenv==1.1.0
```
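The change above switches `openai` from an exact pin (`==1.101.0`) to a version floor (`>=1.0.0`), matching the new `litellm` entry. A toy splitter for the two specifier forms used in this file (illustrative only; real tooling parses specifiers with `packaging.requirements`):

```python
def parse_requirement(line: str):
    """Split a requirements.txt line like "litellm>=1.0.0" into
    (name, operator, version). Handles only the ==/>= forms above."""
    for op in (">=", "=="):
        if op in line:
            name, version = line.split(op, 1)
            return name.strip(), op, version.strip()
    return line.strip(), None, None

print(parse_requirement("litellm>=1.0.0"))
print(parse_requirement("pymupdf==1.26.4"))
```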
